package main

import (
	"bytes"
	"database/sql"
	"encoding/json"
	"fmt"
	"go/ast"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"reflect"
	"sort"
	"strconv"
	"strings"

	"github.com/invopop/jsonschema"
	"golang.org/x/tools/go/packages"

	v6 "github.com/anchore/grype/grype/db/v6"
)

// main regenerates the grype v6 schema artifacts: the SQL schema (see generateSQLSchema)
// and the unified blob JSON schema (see generateBlobSchema). Descriptions for the JSON
// schema come from Go doc comments found in the packages matched by the ".." pattern.
// A failure in either generator is printed and the process exits with status 1.
func main() {
	// the schema version mirrors the database model version constants
	version := fmt.Sprintf("%d.%d.%d", v6.ModelVersion, v6.Revision, v6.Addition)

	comments := parseCommentsFromPackages([]string{".."})
	fmt.Printf("Extracted field comments from %d structs\n", len(comments))

	if err := generateSQLSchema(version); err != nil {
		fmt.Printf("Failed to generate SQL schema: %v\n", err)
		os.Exit(1)
	}

	if err := generateBlobSchema(version, comments); err != nil {
		fmt.Printf("Failed to generate blob JSON schema: %v\n", err)
		os.Exit(1)
	}
}

// generateSQLSchema renders the SQL schema of a database created with all v6 models and
// writes it (via writeFile) as schema-<version>.sql under the "db/sql" component.
//
// The output is a fixed header, every CREATE TABLE statement (sorted, with constraint
// clauses normalized by normalizeCreateTable), then an "-- Indexes" section with every
// CREATE INDEX statement (sorted). Each statement is terminated by ";" and a blank line.
func generateSQLSchema(version string) error {
	// Create an in-memory database with all models
	lowLevel, err := v6.NewLowLevelDB("", true, true, false)
	if err != nil {
		return fmt.Errorf("failed to create database: %w", err)
	}

	conn, err := lowLevel.DB()
	if err != nil {
		return fmt.Errorf("failed to get underlying database: %w", err)
	}
	defer conn.Close()

	tables, err := querySchemaSorted(conn, "table")
	if err != nil {
		return err
	}

	indexes, err := querySchemaSorted(conn, "index")
	if err != nil {
		return err
	}

	var out strings.Builder
	out.WriteString("-- Generated by grype/db/v6/schema\n")
	out.WriteString("-- DO NOT EDIT: This file is auto-generated. Run 'task generate-db-schema' to update.\n")
	fmt.Fprintf(&out, "-- Schema version: %s\n\n", version)

	for _, table := range tables {
		// constraint order is normalized so regeneration is byte-for-byte stable
		out.WriteString(normalizeCreateTable(table) + ";\n\n")
	}

	if len(indexes) > 0 {
		out.WriteString("-- Indexes\n")
		for _, index := range indexes {
			out.WriteString(index + ";\n\n")
		}
	}

	return writeFile(out.String(), "db/sql", version, ".sql")
}

// normalizeCreateTable makes a CREATE TABLE statement deterministic by sorting its
// table-level CONSTRAINT clauses lexically (their order in sqlite_master is not stable
// between runs, e.g. for foreign keys).
//
// The statement is expected to look like
//
//	CREATE TABLE t (<columns>,CONSTRAINT ...,CONSTRAINT ...)
//
// i.e. all constraint clauses follow the column definitions and the statement ends with
// the closing parenthesis of the definition list. Statements without a CONSTRAINT clause,
// or that do not end in ")", are returned unchanged.
func normalizeCreateTable(stmt string) string {
	const keyword = "CONSTRAINT"

	start := strings.Index(stmt, keyword)
	if start < 0 {
		return stmt
	}

	// Remove exactly one closing parenthesis: the one ending the definition list.
	// Constraint clauses frequently end in ")" themselves (e.g. "REFERENCES `t`(`id`)"),
	// so trimming a ")" cutset would also eat those and produce unbalanced SQL.
	body := strings.TrimRight(stmt, " \t\r\n")
	if !strings.HasSuffix(body, ")") {
		return stmt
	}
	body = strings.TrimSuffix(body, ")")

	// everything before the first constraint, i.e. "CREATE TABLE ... (<columns>"
	prefix := strings.TrimRight(body[:start], ",")

	clauses := strings.Split(body[start+len(keyword):], keyword)
	constraints := make([]string, 0, len(clauses))
	for _, clause := range clauses {
		// drop only the single separator comma that joined this clause to the next one
		clause = strings.TrimSpace(clause)
		clause = strings.TrimSpace(strings.TrimSuffix(clause, ","))
		constraints = append(constraints, keyword+" "+clause)
	}
	sort.Strings(constraints)

	return prefix + "," + strings.Join(constraints, ",") + ")"
}

// querySchemaSorted returns the stored SQL text of every sqlite_master object of the given
// type (e.g. "table" or "index") whose name does not match 'sqlite_%%', sorted lexically.
// For indexes, rows whose sql column is NULL are excluded. Sorting the statement text in Go
// (in addition to ORDER BY name) keeps the output deterministic.
//
// NOTE(review): '%%' is a leftover from fmt.Sprintf escaping; in a LIKE pattern it behaves
// the same as a single '%', and '_' matches any single character.
func querySchemaSorted(db *sql.DB, objectType string) ([]string, error) {
	// objectType is bound as a '?' parameter rather than formatted in (gosec G201)
	q := `
		SELECT sql
		FROM sqlite_master
		WHERE type = ?
		AND name NOT LIKE 'sqlite_%%'
	`
	if objectType == "index" {
		q += " AND sql IS NOT NULL"
	}
	q += " ORDER BY name"

	rows, err := db.Query(q, objectType)
	if err != nil {
		return nil, fmt.Errorf("failed to query schema for type %s: %w", objectType, err)
	}
	defer rows.Close()

	var stmts []string
	for rows.Next() {
		var text string
		if scanErr := rows.Scan(&text); scanErr != nil {
			return nil, fmt.Errorf("failed to scan schema: %w", scanErr)
		}
		stmts = append(stmts, text)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("error iterating schema: %w", iterErr)
	}

	sort.Strings(stmts)
	return stmts, nil
}

func generateBlobSchema(version string, comments map[string]map[string]string) error {
	// Create a unified schema that includes all blob types
	schema := buildUnifiedBlobSchema(version, comments)
	encoded := encode(schema)
	return writeFile(string(encoded), "db/blob/json", version, ".json")
}

// buildUnifiedBlobSchema reflects the three v6 blob types into one JSON schema document.
//
// The top level is a oneOf over $refs to VulnerabilityBlob, PackageBlob and
// KnownExploitedVulnerabilityBlob. The definitions reflected for each type are merged, in that
// order, into a single $defs map (on a name collision the later type's definition wins) and
// are then annotated with the given Go doc comments via applyComments. The $id embeds version.
func buildUnifiedBlobSchema(version string, comments map[string]map[string]string) *jsonschema.Schema {
	reflector := &jsonschema.Reflector{
		AllowAdditionalProperties: true,
		// definition names drop a leading "JSON" from the Go type name
		Namer: func(t reflect.Type) string {
			return strings.TrimPrefix(t.Name(), "JSON")
		},
	}

	// order matters: it fixes both the oneOf order and the merge precedence
	blobs := []struct {
		ref string
		typ reflect.Type
	}{
		{ref: "#/$defs/VulnerabilityBlob", typ: reflect.TypeOf(v6.VulnerabilityBlob{})},
		{ref: "#/$defs/PackageBlob", typ: reflect.TypeOf(v6.PackageBlob{})},
		{ref: "#/$defs/KnownExploitedVulnerabilityBlob", typ: reflect.TypeOf(v6.KnownExploitedVulnerabilityBlob{})},
	}

	reflected := make([]*jsonschema.Schema, 0, len(blobs))
	oneOf := make([]*jsonschema.Schema, 0, len(blobs))
	for _, b := range blobs {
		reflected = append(reflected, reflector.ReflectFromType(b.typ))
		oneOf = append(oneOf, &jsonschema.Schema{Ref: b.ref})
	}

	unified := &jsonschema.Schema{
		Version:     jsonschema.Version,
		ID:          jsonschema.ID(fmt.Sprintf("anchore.io/schema/grype/db/blob/json/%s", version)),
		Description: "Unified schema for all blob types stored in the Grype v6 database",
		OneOf:       oneOf,
		Definitions: make(map[string]*jsonschema.Schema),
	}
	for _, s := range reflected {
		mergeDefinitions(unified.Definitions, s.Definitions)
	}

	applyComments(unified.Definitions, comments)
	return unified
}

// mergeDefinitions copies every entry of source into target, replacing any entry of the same
// name already present. A nil or empty source is a no-op; otherwise target must be non-nil.
func mergeDefinitions(target, source map[string]*jsonschema.Schema) {
	for name := range source {
		target[name] = source[name]
	}
}

// applyComments copies Go doc comments onto the matching schema definitions.
//
// comments is keyed by Go struct name and then by property name (the json tag name, or
// the Go field name when there is none). The empty key "" holds the struct-level comment.
// That comment becomes the definition's description. Each non-empty field comment becomes
// the description of the matching entry under the definition's "properties".
// Comments for structs or fields that are not present in the schema are ignored.
func applyComments(definitions map[string]*jsonschema.Schema, comments map[string]map[string]string) {
	for structName, fields := range comments {
		structSchema, ok := definitions[structName]
		if !ok || structSchema == nil {
			continue
		}
		for fieldName, comment := range fields {
			if fieldName == "" {
				// struct-level comment
				structSchema.Description = comment
				continue
			}
			// field level comment; definitions without properties have nothing to annotate
			if comment == "" || structSchema.Properties == nil {
				continue
			}
			// Annotate the property itself. Previously the comment was stored as a new entry in
			// the struct's nested $defs keyed by field name. $defs entries only apply when
			// referenced via $ref, so the property never received its description.
			if propSchema, ok := structSchema.Properties.Get(fieldName); ok && propSchema != nil {
				propSchema.Description = comment
			}
		}
	}
}

func encode(schema *jsonschema.Schema) []byte {
	newSchemaBuffer := new(bytes.Buffer)
	enc := json.NewEncoder(newSchemaBuffer)
	// prevent > and < from being escaped in the payload
	enc.SetEscapeHTML(false)
	enc.SetIndent("", "  ")
	err := enc.Encode(&schema)
	if err != nil {
		panic(err)
	}

	return newSchemaBuffer.Bytes()
}

// writeFile writes content to <repo root>/schema/grype/<component>/schema-<version><extension>
// and mirrors it to schema-latest<extension> in the same directory, creating the directory
// if needed.
//
// An existing versioned schema is never overwritten. If it already holds identical content,
// writeFile reports "no change" and returns nil, leaving schema-latest untouched. If the
// content differs, it returns an error, because the schema changed without a version bump.
// Errors from closing the written files are reported rather than dropped.
func writeFile(content, component, version, extension string) error {
	parent := filepath.Join(repoRoot(), "schema", "grype", component)
	schemaPath := filepath.Join(parent, fmt.Sprintf("schema-%s%s", version, extension))
	latestSchemaPath := filepath.Join(parent, fmt.Sprintf("schema-latest%s", extension))

	// Create parent directory if it doesn't exist
	if err := os.MkdirAll(parent, 0o755); err != nil {
		return fmt.Errorf("unable to create schema directory: %w", err)
	}

	_, statErr := os.Stat(schemaPath)
	if statErr != nil && !os.IsNotExist(statErr) {
		// e.g. permission problems: don't misreport these as "schema exists"
		return fmt.Errorf("unable to stat existing schema %q: %w", schemaPath, statErr)
	}

	if statErr == nil {
		// check if the schema is the same...
		existingFh, err := os.Open(schemaPath)
		if err != nil {
			return fmt.Errorf("unable to open existing schema %q: %w", schemaPath, err)
		}
		defer existingFh.Close()

		existingBytes, err := io.ReadAll(existingFh)
		if err != nil {
			return fmt.Errorf("unable to read existing schema %q: %w", schemaPath, err)
		}

		if string(existingBytes) == content {
			// the generated schema is the same, bail with no error :)
			fmt.Printf("No change to the existing %q schema!\n", component)
			return nil
		}

		// the generated schema is different, bail with error :(
		fmt.Printf("Cowardly refusing to overwrite existing %q schema (%s)!\n", component, schemaPath)
		fmt.Printf("The schema has changed but the version has not been incremented.\n")
		fmt.Printf("See grype/db/v6/db.go to increment the ModelVersion, Revision, or Addition constants.\n")
		return fmt.Errorf("refusing to overwrite existing schema")
	}

	if err := writeStringToFile(schemaPath, content); err != nil {
		return err
	}
	if err := writeStringToFile(latestSchemaPath, content); err != nil {
		return err
	}

	fmt.Printf("Wrote new %q schema to %q\n", component, schemaPath)
	return nil
}

// writeStringToFile creates (or truncates) path and writes content to it. It returns the
// first error from create, write, or close, so a failed close is not silently ignored.
func writeStringToFile(path, content string) (err error) {
	fh, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("unable to create %q: %w", path, err)
	}
	defer func() {
		if closeErr := fh.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("unable to close %q: %w", path, closeErr)
		}
	}()

	if _, err = fh.WriteString(content); err != nil {
		return fmt.Errorf("unable to write %q: %w", path, err)
	}
	return nil
}

// parseCommentsFromPackages loads the packages matched by pkgPatterns and collects the doc
// comments of every struct type declared in their files (see parseFileComments). The result
// is keyed by struct name. When a struct name occurs more than once, the first occurrence in
// load order wins. It panics if the packages cannot be loaded.
//
// NOTE(review): per-package errors reported in pkg.Errors are not inspected here.
func parseCommentsFromPackages(pkgPatterns []string) map[string]map[string]string {
	cfg := &packages.Config{
		Mode: packages.NeedFiles | packages.NeedSyntax | packages.NeedDeps | packages.NeedImports,
	}
	loaded, err := packages.Load(cfg, pkgPatterns...)
	if err != nil {
		panic(fmt.Errorf("failed to load packages: %w", err))
	}

	byStruct := make(map[string]map[string]string)
	for _, pkg := range loaded {
		for _, file := range pkg.Syntax {
			for name, fields := range parseFileComments(file) {
				// keep the first definition seen for a given struct name
				if _, seen := byStruct[name]; seen {
					continue
				}
				byStruct[name] = fields
			}
		}
	}
	return byStruct
}

// parseFileComments extracts doc comments for every struct type declared in a single file.
//
// The result maps struct name -> key -> cleaned comment (see cleanComment). The key is ""
// for the struct's own doc comment. For a field, the key is its json tag name, or the Go
// field name when there is no json tag name. Details:
//   - only the first name of a multi-name field is used;
//   - embedded fields are skipped;
//   - comments containing "TODO:" are ignored;
//   - structs that yield no comments are omitted.
func parseFileComments(node *ast.File) map[string]map[string]string {
	commentMap := make(map[string]map[string]string)

	ast.Inspect(node, func(n ast.Node) bool {
		decl, ok := n.(*ast.GenDecl)
		if !ok {
			return true
		}
		for _, spec := range decl.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			st, ok := ts.Type.(*ast.StructType)
			if !ok {
				continue
			}

			// For an ungrouped declaration ("type Foo struct{...}") the parser attaches the
			// doc comment to the GenDecl and leaves TypeSpec.Doc nil, so fall back to it.
			// Inside a "type ( ... )" group each spec carries its own doc.
			doc := ts.Doc
			if doc == nil && !decl.Lparen.IsValid() {
				doc = decl.Doc
			}

			fieldComments := make(map[string]string)

			// extract struct-level comment
			if doc != nil {
				structComment := strings.TrimSpace(doc.Text())
				if !strings.Contains(structComment, "TODO:") {
					fieldComments[""] = cleanComment(structComment)
				}
			}

			// extract field-level comments
			for _, field := range st.Fields.List {
				if len(field.Names) == 0 || field.Doc == nil {
					continue // embedded or undocumented field
				}
				comment := strings.TrimSpace(field.Doc.Text())
				if strings.Contains(comment, "TODO:") {
					continue
				}
				key := getJSONTag(field)
				if key == "" {
					key = field.Names[0].Name
				}
				fieldComments[key] = cleanComment(comment)
			}

			if len(fieldComments) > 0 {
				commentMap[ts.Name.Name] = fieldComments
			}
		}
		return true
	})

	return commentMap
}

// cleanComment normalizes a doc comment for use as a schema description. It drops everything
// up to and including the first space, which under go-doc conventions is the documented
// identifier. Comments containing no space are kept whole. The result is whitespace-trimmed,
// with double quotes replaced by single quotes.
func cleanComment(comment string) string {
	if _, rest, found := strings.Cut(comment, " "); found {
		comment = rest
	}
	return strings.ReplaceAll(strings.TrimSpace(comment), "\"", "'")
}

// getJSONTag returns the name part of a struct field's `json` tag: the text before the first
// comma, whitespace-trimmed. It returns "" when the field is nil, has no tag, or has no json
// key. The name itself may also be "" (e.g. `json:",omitempty"`) or "-"; it is returned as-is.
func getJSONTag(field *ast.Field) string {
	if field == nil || field.Tag == nil {
		return ""
	}
	// Tag.Value is the tag's Go source literal, which may be raw (`...`) or interpreted
	// ("..."); unquote it properly instead of only trimming backticks.
	tagValue, err := strconv.Unquote(field.Tag.Value)
	if err != nil {
		// malformed literal: fall back to the previous best-effort behavior
		tagValue = strings.Trim(field.Tag.Value, "`")
	}
	jsonTag, ok := reflect.StructTag(tagValue).Lookup("json")
	if !ok {
		return ""
	}
	name, _, _ := strings.Cut(jsonTag, ",")
	return strings.TrimSpace(name)
}

// repoRoot returns the absolute path printed by `git rev-parse --show-toplevel` when run from
// the current working directory. It panics if git fails or the path cannot be made absolute.
func repoRoot() string {
	out, err := exec.Command("git", "rev-parse", "--show-toplevel").Output()
	if err != nil {
		// %w (rather than %+v) keeps the underlying exec error unwrappable, matching below
		panic(fmt.Errorf("unable to find repo root dir: %w", err))
	}
	root, err := filepath.Abs(strings.TrimSpace(string(out)))
	if err != nil {
		panic(fmt.Errorf("unable to get abs path to repo root: %w", err))
	}
	return root
}
